from __future__ import absolute_import
from __future__ import division
import time
from contextlib import ExitStack
import numpy as np
from pprint import pformat
import datetime
import os
import pandas as pd
import dill
from collections import defaultdict
import traceback
import json
import subprocess
from tabulate import tabulate
import warnings

from dps import cfg, init
from dps.utils import (
    gen_seed, time_limit, Alarm, memory_usage, gpu_memory_usage, ExperimentStore,
    ExperimentDirectory, nvidia_smi, memory_limit, Config, redirect_stream, pretty_func,
    NumpySeed, restart_tensorboard, launch_pdb_on_exception, execute_command, flush_print as _print,
)


def training_loop(exp_name='', start_time=None):
    """ Build the training loop for the configured framework and run it.

    The framework is read from `cfg.framework` (default 'tensorflow'); the matching loop
    class is imported lazily so that only the selected framework needs to be installed.

    Parameters
    ----------
    exp_name: str
        Name of the experiment, forwarded to the training loop constructor.
    start_time: float or None
        Start time (in seconds since epoch) used for elapsed-time limits; forwarded to
        the loop's `run` method.

    Returns
    -------
    Whatever the selected loop's `run` method returns.

    Raises
    ------
    Exception
        If `cfg.framework` is neither 'tensorflow' nor 'pytorch'.

    """
    init()

    framework = cfg.get('framework', 'tensorflow')

    if framework == 'tensorflow':
        from dps.tf.train import TensorFlowTrainingLoop
        loop = TensorFlowTrainingLoop(exp_name)

    elif framework == 'pytorch':
        from dps.pytorch.train import PyTorchTrainingLoop
        loop = PyTorchTrainingLoop(exp_name)

    else:
        # The braces around the option set are doubled so that str.format emits them
        # literally instead of treating them as a replacement field (which raised KeyError).
        raise Exception(
            "Unknown framework: {}. Options are {{'tensorflow', 'pytorch'}}.".format(framework))

    return loop.run(start_time)


class EarlyStopHook:
    """ Track the best value of a stopping criterion and decide when to stop early.

    Parameters
    ----------
    patience: int
        Early stopping triggers once more than `patience` steps have passed since the last
        new best. Values <= 0 disable early stopping.
    maximize: bool
        If True, larger values of the stopping criterion are better; otherwise smaller ones are.
    start: int or None
        If not None, values observed before step `start` are replaced by the worst possible
        value, and early stopping cannot trigger until after step `start`.

    """
    def __init__(self, patience, maximize, start):
        self.patience = patience
        self.maximize = maximize
        self.start = start
        self.reset()

    def _check_trigger(self, sc):
        """ True if `sc` strictly improves on the best value so far, or if none has been seen yet. """
        incumbent = self._best_stopping_criteria
        if incumbent is None:
            return True
        return sc > incumbent if self.maximize else sc < incumbent

    def check(self, stopping_criteria, step, record):
        """ Register a new observation of the stopping criterion.

        Parameters
        ----------
        stopping_criteria: number
            Value of the stopping criterion at `step`.
        step: int
            Step at which the value was observed.
        record: mapping with a `copy` method
            Record to remember (as a copy) if this observation becomes the new best.

        Returns
        -------
        (new_best, early_stopped)
            Whether this observation is a new best, and whether early stopping has been
            triggered at any point so far.

        """
        if self.start is not None and step < self.start:
            # Too early to count: substitute the worst possible value.
            stopping_criteria = -np.inf if self.maximize else np.inf

        is_new_best = self._check_trigger(stopping_criteria)

        if is_new_best:
            self._best_stopping_criteria = stopping_criteria
            self._best_step = step
            self._best_record = record.copy()

        if self.patience > 0:
            stale = step - self._best_step > self.patience
            past_start = self.start is None or step > self.start
            self._early_stopped = self._early_stopped or (stale and past_start)

        return is_new_best, self._early_stopped

    @property
    def best(self):
        """ Copy of the best record, extended with its stopping criterion value and step. """
        summary = self._best_record.copy()
        summary.update(stopping_criteria=self._best_stopping_criteria, local_step=self._best_step)
        return summary

    def reset(self):
        """ Forget all observations and clear the early-stopped flag. """
        self._best_stopping_criteria = self._best_record = self._best_step = None
        self._early_stopped = 0


class TrainingLoop:
    """ A training loop.

    The behaviour of the training loop depends on the context stack that is active when it is
    run (i.e. `run` method is called), not the one that is active when it is created.

    Parameters
    ----------
    exp_name: str
        Name of the experiment, used as a prefix when creating a directory for storing data
        generated by the training run.

    """
    def __init__(self, exp_name=''):
        # An empty/falsy name falls back to the globally configured experiment name.
        if not exp_name:
            exp_name = cfg.exp_name
        self.exp_name = exp_name
        # Set for real when `run` is called.
        self.start_time = None

    """ Abstract methods """

    def framework_initialize(self):
        """ Framework-specific initialization; subclasses must override this.

        Raises a plain ``Exception`` (not ``NotImplementedError``) when not overridden.
        """
        message = "NotImplemented"
        raise Exception(message)

    @property
    def time_remaining(self):
        """ Seconds left before the `cfg.max_time` budget is used up.

        Returns ``np.inf`` when `cfg.max_time` is None or non-positive (no limit).
        Requires `self.start_time` to be set.
        """
        budget = cfg.max_time
        if budget is None or budget <= 0:
            return np.inf
        elapsed = time.time() - self.start_time
        return budget - elapsed

    def edit_remaining_stage(self, idx, stage_config):
        """ Merge `stage_config` into the not-yet-run stage at position `idx`.

        If fewer than ``idx + 1`` stages remain, empty stage dicts are appended first so
        that position exists.
        """
        missing = idx + 1 - len(self.curriculum_remaining)
        self.curriculum_remaining.extend(dict() for _ in range(missing))
        self.curriculum_remaining[idx].update(stage_config)

    def timestamp(self, message):
        """ Print `message` followed by the current time, elapsed seconds and remaining seconds.

        Requires `self.start_time` to have been set (done at the top of `run`).
        """
        prefix = message + " " if message else message
        now = datetime.datetime.now()
        elapsed = time.time() - self.start_time
        remaining = self.time_remaining
        _print("{}({}, {:.2f}s elapsed, {:.2f}s remaining)".format(prefix, now, elapsed, remaining))

    def get_load_paths(self):
        """ Resolve `cfg.load_path` into a list of (module_path, filesystem_path) pairs.

        Let a *path_specification* be one of three things:
            1. An integer specifying a previous stage to load the best hypothesis from.
            2. A string of format: "stage_idx,kind" where `stage_idx` specifies a previous stage to load from
               and `kind` is either "final" or "best", specifying whether to load final or best
               hypothesis from that stage. If `kind` is omitted, "best" is used.
            3. A path on the filesystem that gives a prefix for a tensorflow checkpoint file to load from.

        Then cfg.load_path can either be a path_specification itself, in which case all variables
        in the network will be loaded from that path_specification, or a dictionary mapping from
        variable scope names to path specifications, in which case all variables in each supplied
        variable scope name will be loaded from the path_specification paired with that scope name.

        Returns
        -------
        list of (str, str)
            (module_path, absolute path) pairs sorted by module path. Stage-based specifications
            are skipped while in stage 0, since no stage has been completed yet. Empty when
            `cfg.load_path` is unset.

        """
        load_path = cfg.load_path
        _print("\nMaybe loading weights, load_path={} ...".format(load_path))

        # Stage index 0 is a valid specification even though it is falsy.
        is_stage_zero = (
            isinstance(load_path, int) and not isinstance(load_path, bool) and load_path == 0)

        if not load_path and not is_stage_zero:
            _print("`load_path` is null, using a fresh set of weights.")
            return []

        if isinstance(load_path, (str, int)):
            load_path = {"": load_path}

        # Sort in increasing order, so that if one variable scope lies within another scope,
        # the outer scope gets loaded before the inner scope, rather than having the outer scope
        # wipe out the inner scope.
        items = sorted(dict(load_path).items())

        resolved = []
        for module_path, path in items:
            load_stage, kind = None, None

            try:
                load_stage = int(path)
                kind = "best"
            except (TypeError, ValueError):
                try:
                    split = path.split(',')
                    load_stage = int(split[0])
                    # The optional second field selects the checkpoint kind; default is "best".
                    kind = split[1].strip() if len(split) > 1 else 'best'
                    if kind not in ('best', 'final'):
                        raise ValueError("path={}".format(path))
                except Exception:
                    # Not a stage specification; `path` is treated as a filesystem path.
                    load_stage, kind = None, None

            if load_stage is not None:
                if self.stage_idx == 0:
                    _print(
                        "Not loading submodule \"{}\" from stage {}, "
                        "currently in stage 0.".format(module_path, load_stage))
                    continue

                key = kind + '_path'
                # The last history entry is the stage currently running, so exclude it.
                completed_history = self.data.history[:-1]
                path = completed_history[load_stage][key]

            resolved.append((module_path, os.path.realpath(path)))

        return resolved

    def run(self, start_time):
        """ Run the training loop.

        Calls any configured prepare functions, creates a fresh experiment directory (with a
        README), then runs all curriculum stages with stdout/stderr teed into that directory,
        the numpy seed fixed, warnings filtered and an overall time limit in place. Finally
        calls any configured finalize functions.

        Parameters
        ----------
        start_time: int
            Start time (in seconds since epoch) for measuring elapsed time for
            purposes of interrupting the training loop. If None, the current time is used.

        Returns
        -------
        The value of `self.data.freeze()`, taken after training finishes.

        """
        def _as_func_list(spec):
            # `spec` may be a single callable, an iterable of callables, or something that
            # is neither (e.g. None), in which case there is nothing to call.
            if callable(spec):
                return [spec]
            try:
                return list(spec)
            except (TypeError, ValueError):
                return []

        if start_time is None:
            start_time = time.time()
        self.start_time = start_time

        self.timestamp("Entering TrainingLoop.run")

        # Call prepare_func to modify the config in arbitrary ways before training
        for f in _as_func_list(cfg.get("prepare_func", None)):
            if callable(f):
                _print("Calling prepare func {}...".format(pretty_func(f)))
                f()

        self.curriculum = cfg.curriculum + []

        if cfg.start_from:
            # `start_from` has the form "stage_idx,step".
            initial_stage, initial_step = cfg.start_from.split(',')
            cfg.initial_stage = initial_stage = int(initial_stage)
            initial_step = int(initial_step)
            if initial_step != 0:
                # `self.curriculum` is only a shallow copy of `cfg.curriculum`; copy the stage
                # entry before editing it so the global config is not silently mutated.
                stage = self.curriculum[initial_stage].copy()
                stage['initial_step'] = initial_step
                self.curriculum[initial_stage] = stage

        if cfg.seed is None or cfg.seed < 0:
            cfg.seed = gen_seed()

        # Create a directory to store the results of the training run.
        self.experiment_store = ExperimentStore(os.path.join(cfg.local_experiments_dir, cfg.env_name))

        filename_keys = cfg.get('filename_keys', [])
        if isinstance(filename_keys, str):
            filename_keys = filename_keys.split(',')
        filename_data = {key: str(cfg[key]) for key in filename_keys if key}

        exp_dir = self.experiment_store.new_experiment(
            self.exp_name, cfg.seed, data=filename_data,
            add_date=1, force_fresh=1, update_latest=cfg.update_latest)

        self.exp_dir = exp_dir
        cfg.path = exp_dir.path

        breaker = "-" * 40
        header = "{}\nREADME.md - {}\n{}\n\n\n".format(breaker, os.path.basename(exp_dir.path), breaker)
        readme = header + (cfg.readme if cfg.readme else "") + "\n\n"

        with open(exp_dir.path_for('README.md'), 'w') as f:
            f.write(readme)

        self.data = self.training_loop_data_class(exp_dir)
        self.data.setup()

        frozen_data = None

        with ExitStack() as stack:
            if cfg.pdb:
                stack.enter_context(launch_pdb_on_exception())

                _print("`pdb` is turned on, so forcing setting robust=False")
                cfg.robust = False

            stack.enter_context(redirect_stream('stdout', self.data.path_for('stdout'), tee=cfg.tee))
            stack.enter_context(redirect_stream('stderr', self.data.path_for('stderr'), tee=cfg.tee))

            stack.enter_context(warnings.catch_warnings())
            warnings.simplefilter(cfg.warning_mode)

            _print("\n\n" + "=" * 80)
            self.timestamp("Starting training run (name={})".format(self.exp_name))

            _print("\nDirectory for this training run is {}.".format(exp_dir.path))

            stack.enter_context(NumpySeed(cfg.seed))
            _print("\nSet numpy random seed to {}.\n".format(cfg.seed))

            limiter = time_limit(
                self.time_remaining, verbose=True,
                timeout_callback=lambda limiter: _print("Training run exceeded its time limit."))

            try:
                with limiter:
                    self._run()

            finally:
                self.data.summarize()

                self.timestamp(f"Done training run (name={self.exp_name})")
                _print(f"Experiment lives at {self.exp_dir.path}")
                _print("=" * 80)
                _print("\n\n")

                frozen_data = self.data.freeze()

        # `pretty_func` (as for prepare funcs) also handles callables without `__name__`,
        # e.g. functools.partial objects, which previously crashed the run at this point.
        for f in _as_func_list(cfg.get("finalize_func", None)):
            if callable(f):
                _print("Calling finalize func {}...".format(pretty_func(f)))
                f()

        self.timestamp("Leaving TrainingLoop.run")

        return frozen_data

    def _run(self):
        """ Run each remaining stage of the curriculum, one after another.

        The stages run are those of `self.curriculum` (set up by `run`), truncated to
        `cfg.max_stages` stages when that is truthy, and skipping the first `cfg.initial_stage`
        stages when that is not None.

        For each stage this:
        - optionally (re)starts tensorboard;
        - enters the stage's config as a context;
        - builds the env (or reuses it under `cfg.preserve_env`) and the updater;
        - starts the hooks and calls `_run_stage`.

        Afterwards, even if the stage failed, it:
        - records duration and memory statistics;
        - saves final weights;
        - runs final testing/rendering unless the stage ended with an exception or a time-out;
        - runs end-of-stage hooks.

        The loop stops early when a stage does not reach its stopping-criteria threshold,
        unless `cfg.power_through` is set.

        Raises
        ------
        Exception
            If `cfg.initial_stage` is negative. Errors raised while running a stage propagate
            unless `cfg.robust` is set (a `NotImplementedError` is re-raised as a plain
            `Exception`). `KeyboardInterrupt` and `Alarm` always propagate.

        """
        _print(cfg)

        self.global_step = 0
        self.n_global_experiences = 0
        max_stages = cfg.get('max_stages', 0) or None
        self.curriculum_remaining = self.curriculum[:max_stages] + []
        self.curriculum_complete = []

        if cfg.initial_stage is not None:
            if cfg.initial_stage >= 0:
                self.stage_idx = cfg.initial_stage
            else:
                raise Exception("Initial stage cannot be negative: {}".format(cfg.initial_stage))
            self.curriculum_remaining = self.curriculum_remaining[self.stage_idx:]
        else:
            self.stage_idx = 0

        while self.curriculum_remaining:
            _print("\n" + "=" * 50)
            self.timestamp("Starting stage {}".format(self.stage_idx))
            _print("\n")

            if cfg.start_tensorboard:
                if cfg.start_tensorboard == "local":
                    tb_path = self.exp_dir.path
                else:
                    try:
                        n_latest = int(cfg.start_tensorboard)
                        tb_path = self.experiment_store.isolate_n_latest(n_latest)
                    except (ValueError, TypeError):
                        tb_path = self.experiment_store.path

                restart_tensorboard(tb_path, cfg.tbport, cfg.reload_interval)

            stage_config = self.curriculum_remaining.pop(0)
            stage_config = Config(stage_config)

            self.data.start_stage(self.stage_idx, stage_config)

            with ExitStack() as stack:

                # --------------- Stage set-up -------------------

                _print("\n" + "-" * 10 + " Stage set-up " + "-" * 10)

                _print("\nNew config values for this stage are: \n{}\n".format(stage_config))

                stack.enter_context(stage_config)

                stage_prepare_func = cfg.get("stage_prepare_func", None)
                if callable(stage_prepare_func):
                    stage_prepare_func()  # Modify the stage config in arbitrary ways before starting stage

                # Set limit on CPU RAM for the stage
                cpu_ram_limit_mb = cfg.get("cpu_ram_limit_mb", None)
                if cpu_ram_limit_mb is not None:
                    stack.enter_context(memory_limit(cpu_ram_limit_mb))

                self.framework_initialize_stage(stack)

                _print("Building env...\n")

                # Maybe build env. Besides the first stage and the `preserve_env`-off case, an env
                # must also be built when none exists yet (e.g. when starting at a later stage via
                # `cfg.initial_stage`), since there is nothing to preserve.
                if self.stage_idx == 0 or not cfg.preserve_env or getattr(self, 'env', None) is None:
                    if getattr(self, 'env', None):
                        self.env.close()
                    self.env = cfg.build_env()

                if hasattr(self.env, "print_memory_footprint"):
                    self.env.print_memory_footprint()

                _print("\nDone building env.\n")
                _print("Building updater...\n")

                updater = cfg.get_updater(self.env)
                self.updater = updater

                updater.stage_idx = self.stage_idx
                updater.exp_dir = self.exp_dir

                updater.build_graph()
                _print("\nDone building updater.\n")

                # --- build hooks ---

                for hook in cfg.hooks:
                    assert isinstance(hook, Hook)
                    hook.start_stage(self, updater, self.stage_idx)

                if cfg.render_hook is not None:
                    cfg.render_hook.start_stage(self, updater, self.stage_idx)

                self.framework_finalize_stage_initialization()

                threshold_reached = False
                reason = None
                ran_ok = False

                try:
                    # --------------- Run stage -------------------

                    start = time.time()
                    phys_memory_before = memory_usage(physical=True)
                    gpu_memory_before = gpu_memory_usage()

                    threshold_reached, reason = self._run_stage(self.stage_idx, updater)

                    ran_ok = True

                except KeyboardInterrupt:
                    reason = "User interrupt"
                    raise

                except Alarm:
                    # Handled before the generic `Exception` clause so that a time-limit interrupt
                    # is always recorded and re-raised as such, whatever Alarm's base class is.
                    reason = "Time limit exceeded"
                    raise

                except NotImplementedError as e:
                    # There is a bug in pdb_postmortem that prevents instances of `NotImplementedError`
                    # from being handled properly, so replace it with an instance of `Exception`.
                    if cfg.robust:
                        traceback.print_exc()
                        reason = "Exception occurred ({})".format(repr(e))
                    else:
                        raise Exception("NotImplemented") from e

                except Exception as e:
                    reason = "Exception occurred ({})".format(repr(e))
                    if cfg.robust:
                        traceback.print_exc()
                    else:
                        raise

                finally:

                    try:
                        phys_memory_after = memory_usage(physical=True)
                        gpu_memory_after = gpu_memory_usage()

                        self.data.record_values_for_stage(
                            stage_duration=time.time()-start,
                            phys_memory_before_mb=phys_memory_before,
                            phys_memory_delta_mb=phys_memory_after - phys_memory_before,
                            gpu_memory_before_mb=gpu_memory_before,
                            gpu_memory_delta_mb=gpu_memory_after - gpu_memory_before
                        )

                        self.data.record_values_for_stage(reason=reason)

                        _print("\n" + "-" * 10 + " Optimization complete " + "-" * 10)
                        _print("\nReason: {}.\n".format(reason))

                        _print("Storing final weights...")
                        weight_start = time.time()
                        final_path = self.data.path_for('weights/final_stage_{}'.format(self.stage_idx))
                        final_path = cfg.get('save_path', final_path)
                        final_path = updater.save(final_path)
                        _print("Done saving weights, took {} seconds".format(time.time() - weight_start))

                        self.data.record_values_for_stage(final_path=final_path)

                        # --------------- Maybe test and render with best hypothesis -------------------

                        # Skip after a crash or time-out. `reason` is still None when an exception
                        # not handled above (e.g. SystemExit) is propagating; skip in that case too
                        # rather than failing on the `in` test and losing the end-of-stage hooks.
                        do_final_testing = (
                            reason is not None
                            and "Exception occurred" not in reason
                            and reason != "Time limit exceeded"
                        )

                        if do_final_testing:
                            try:
                                _print("\n" + "-" * 10 + " Final testing/rendering " + "-" * 10)

                                if 'best_path' in self.data.current_stage_record:
                                    best_path = self.data.current_stage_record['best_path']

                                    _print("Best hypothesis for this stage was found on "
                                           "step (l: {best_local_step}, g: {best_global_step}) "
                                           "with stopping criteria ({sc_name}) of {best_stopping_criteria}.".format(
                                               sc_name=self.stopping_criteria_name, **self.data.current_stage_record))

                                    _print("Loading best hypothesis for this stage "
                                           "from file {}...".format(best_path))
                                    updater.restore(best_path)
                                else:
                                    _print("No `best_path` found, testing with final weights instead.")

                                try:
                                    test_record = updater.evaluate(cfg.batch_size, self.local_step, mode="test")
                                except Exception:
                                    _print("Encountered error running final tests: ")
                                    traceback.print_exc()

                                    test_record = {}

                                for hook in cfg.hooks:
                                    if hook.final:
                                        hook_record = hook.final_step(self, updater)

                                        if hook_record:
                                            assert len(hook_record) == 1
                                            for d in dict(hook_record).values():
                                                test_record.update(d)

                                self.data.record_values_for_stage(
                                    **{'test_' + k: v for k, v in test_record.items()})

                                if cfg.render_final and cfg.render_hook is not None:
                                    _print("Rendering...")
                                    cfg.render_hook(updater)
                                    _print("Done rendering.")

                                self.data.summarize()

                            except BaseException:
                                _print("Exception occurred while performing final testing/rendering: ")
                                traceback.print_exc()

                        else:
                            _print("\n" + "-" * 10 + " Skipping final testing/rendering " + "-" * 10)

                        # --------------- Finish up the stage -------------------

                        _print("\n" + "-" * 10 + " Running end-of-stage hooks " + "-" * 10 + "\n")
                        for hook in cfg.hooks:
                            hook.end_stage(self, updater, self.stage_idx)

                        self.data.end_stage(self.local_step)

                        _print()
                        self.timestamp("Done stage {}".format(self.stage_idx))
                        _print("=" * 50)

                        self.stage_idx += 1
                        self.curriculum_complete.append(stage_config)

                    except Exception:
                        # If there is already an exception, we want to post-mortem as the original exception,
                        # not the one caused by finalization.
                        if ran_ok:
                            raise
                        else:
                            _print("Ignoring exception triggered while finalizing:")
                            traceback.print_exc()

                if not (threshold_reached or cfg.power_through):
                    _print("Failed to reach stopping criteria threshold on stage {} "
                           "of the curriculum, terminating.".format(self.stage_idx-1))
                    break

    def _run_stage(self, stage_idx, updater):
        """ Run main training loop for a stage of the curriculum. """

        threshold_reached = False
        reason = "NotStarted"

        stopping_criteria = cfg.stopping_criteria

        if isinstance(stopping_criteria, str):
            stopping_criteria = stopping_criteria.split(",")

        self.stopping_criteria_name = stopping_criteria[0]
        if "max" in stopping_criteria[1]:
            self.maximize_sc = True
            stopping_criteria_value = -np.inf
        elif "min" in stopping_criteria[1]:
            self.maximize_sc = False
            stopping_criteria_value = np.inf
        else:
            raise Exception("Ambiguous stopping criteria specification: {}".format(stopping_criteria[1]))

        early_stop = EarlyStopHook(
            patience=cfg.patience, maximize=self.maximize_sc, start=cfg.get('patience_start', None))

        _print("\n" + "-" * 10 + " Training begins " + "-" * 10)
        self.timestamp("")

        total_hooks_time = 0.0
        time_per_hook = 0.0

        total_eval_time = 0.0
        time_per_eval = 0.0

        total_train_time = 0.0
        time_per_example = 0.0
        time_per_update = 0.0

        n_updates = 0
        n_evals = 0
        if cfg.initial_step is not None and cfg.initial_step > 0:
            self.local_step = cfg.initial_step
        else:
            self.local_step = 0

        n_fallbacks = 0

        while True:
            local_step = self.local_step
            global_step = self.global_step

            # --- check whether to keep training ---

            if local_step >= cfg.max_steps:
                reason = "Maximum number of steps-per-stage reached"
                break

            if updater.n_experiences >= cfg.max_experiences:
                reason = "Maximum number of experiences-per-stage reached"
                break

            # --- check which steps to run ---

            render_step = cfg.eval_step if cfg.render_step <= 0 else cfg.render_step
            display_step = cfg.eval_step if cfg.display_step <= 0 else cfg.display_step
            checkpoint_step = cfg.eval_step if cfg.checkpoint_step <= 0 else cfg.checkpoint_step
            weight_step = cfg.eval_step if cfg.weight_step <= 0 else cfg.weight_step
            backup_step = cfg.eval_step if cfg.backup_step <= 0 else cfg.backup_step

            evaluate = local_step % cfg.eval_step == 0 and (local_step > 0 or cfg.get('eval_first', True))
            display = local_step % display_step == 0 and local_step > 0
            render = local_step % render_step == 0 and (local_step > 0 or cfg.render_first)
            checkpoint = local_step % checkpoint_step == 0 and local_step > 0
            save_weights = local_step % weight_step == 0 and local_step > 0
            save_weights_steps = cfg.get('save_weights_steps', [])
            save_weights |= local_step in save_weights_steps
            overwrite_weights = cfg.overwrite_weights and local_step not in save_weights_steps
            backup = local_step % backup_step == 0 and local_step > 0 and cfg.backup_dir

            if display or render or evaluate or local_step % 100 == 0:
                _print("\n{} Starting step {} {}".format("-" * 40, local_step, "-" * 40), flush=True)
                self.timestamp("")
                _print("")

            data_to_store = []

            try:
                updater.step = local_step

                # --------------- Run hooks -------------------

                hooks_start = time.time()

                for hook in cfg.hooks:
                    if hook.call_per_timestep:
                        run_hook = local_step == 0 and hook.initial
                        run_hook |= local_step > 0 and local_step % hook.n == 0

                        if run_hook:
                            hook_record = hook.step(self, updater, local_step)

                            if hook_record:
                                data_to_store.extend(dict(hook_record).items())

                hooks_duration = time.time() - hooks_start

                if render and cfg.render_hook is not None:
                    _print("Rendering...")

                    start = time.time()
                    if cfg.robust:
                        try:
                            cfg.render_hook(updater)
                        except Exception:
                            pass
                    else:
                        cfg.render_hook(updater)

                    _print("Done rendering, took {} seconds.".format(time.time() - start))

                # --------------- Possibly evaluate -------------------

                if evaluate:
                    _print("Evaluating...")
                    eval_start_time = time.time()
                    val_record = updater.evaluate(cfg.batch_size, local_step, mode="val")
                    eval_duration = time.time() - eval_start_time
                    _print("Done evaluating, took {} seconds.".format(eval_duration))

                    val_record["duration"] = eval_duration

                    n_evals += 1
                    total_eval_time += eval_duration
                    time_per_eval = total_eval_time / n_evals

                    val_record = Config(val_record)

                    data_to_store.append(("val", val_record))

                    if self.stopping_criteria_name in val_record:
                        stopping_criteria_value = val_record[self.stopping_criteria_name]
                    else:
                        stopping_criteria_names = [
                            k for k in val_record.flatten().keys() if k.startswith(self.stopping_criteria_name)]

                        if len(stopping_criteria_names) == 0:
                            _print("Stopping criteria {} not in record returned "
                                   "by updater, using 0.0.".format(self.stopping_criteria_name))
                            stopping_criteria_value = 0.0

                        elif len(stopping_criteria_names) > 1:
                            _print("stopping_criteria_name `{}` picks out multiple values: {}, using "
                                   "0.0".format(self.stopping_criteria_name, stopping_criteria_names))
                            stopping_criteria_value = 0.0
                        else:
                            stopping_criteria_value = val_record[stopping_criteria_names[0]]

                    new_best, stop = early_stop.check(stopping_criteria_value, local_step, val_record)

                    if new_best:
                        _print("Storing new best on step (l={}, g={}), "
                               "constituting (l={}, g={}) experiences, "
                               "with stopping criteria ({}) of {}.".format(
                                   local_step, global_step,
                                   updater.n_experiences, self.n_global_experiences,
                                   self.stopping_criteria_name, stopping_criteria_value))

                        best_path = self.data.path_for('weights/best_stage_{}'.format(stage_idx))
                        best_path = cfg.get('save_path', best_path)

                        weight_start = time.time()
                        best_path = updater.save(best_path)

                        _print("Done saving weights, took {} seconds".format(time.time() - weight_start))

                        self.data.record_values_for_stage(
                            best_path=best_path, best_global_step=global_step)
                        self.data.record_values_for_stage(
                            **{'best_' + k: v for k, v in early_stop.best.items()})

                    if stop:
                        _print("Early stopping triggered.")
                        reason = "Early stopping triggered"
                        break

                    threshold = cfg.get('threshold', None)
                    if threshold is not None:
                        if self.maximize_sc:
                            threshold_reached = stopping_criteria_value >= threshold
                        else:
                            threshold_reached = stopping_criteria_value <= threshold

                        if threshold_reached:
                            reason = "Stopping criteria threshold reached"
                            break

                # --------------- Perform an update -------------------

                if cfg.do_train:
                    if local_step % 100 == 0:
                        _print("Running update step {}...".format(local_step))

                    update_start_time = time.time()

                    _old_n_experiences = updater.n_experiences

                    update_record = updater.update(cfg.batch_size, local_step)

                    n_updates += 1

                    update_duration = time.time() - update_start_time
                    update_record["duration"] = update_duration

                    n_experiences_delta = updater.n_experiences - _old_n_experiences
                    self.n_global_experiences += n_experiences_delta

                    total_train_time += update_duration
                    time_per_example = total_train_time / updater.n_experiences
                    time_per_update = total_train_time / n_updates

                    total_hooks_time += hooks_duration
                    time_per_hook = total_hooks_time / n_updates

                    if local_step % 100 == 0:
                        _print("Done update step, took {} seconds.".format(update_duration))
                        _print("Average time per update: {} seconds".format(time_per_update))

                        start = time.time()
                        update_record["memory_physical_mb"] = memory_usage(physical=True)
                        update_record["memory_virtual_mb"] = memory_usage(physical=False)
                        update_record["memory_gpu_mb"] = gpu_memory_usage()
                        _print("Memory check duration: {}".format(time.time() - start))

                    if evaluate:
                        # Only store train data as often as we evaluate, otherwise it's just too much data
                        data_to_store.append(('train', update_record))

            except Exception as e:
                if not cfg.max_n_fallbacks:
                    raise e

                traceback.print_exc()

                n_fallbacks += 1

                if n_fallbacks > cfg.max_n_fallbacks:
                    _print(f"Fell back too many times ({n_fallbacks} times).")
                    raise e

                weight_dir = self.data.path_for('weights')
                weight_files = [f for f in os.listdir(weight_dir) if f.startswith(f'checkpoint_stage_{stage_idx}')]
                if not weight_files:
                    _print("Tried to fall back, but no checkpoint weights were found.")
                    raise e
                weight_file = sorted(weight_files)[-1]
                weight_path = os.path.join(weight_dir, weight_file)

                _print(f"Falling back to checkpoint weights: {weight_path}")

                updater.restore(weight_path)

                self.local_step += 1
                self.global_step += 1

                continue

            # --------------- Store data -------------------

            records = defaultdict(dict)
            for mode, r in data_to_store:
                r = Config(r).flatten()
                records[mode].update(r)

            self.data.store_step_data_and_summaries(
                stage_idx, local_step, global_step,
                updater.n_experiences, self.n_global_experiences,
                **records)

            self.data.record_values_for_stage(
                time_per_example=time_per_example,
                time_per_update=time_per_update,
                time_per_eval=time_per_eval,
                time_per_hook=time_per_hook,
                n_steps=local_step,
                n_experiences=updater.n_experiences,
            )

            if display:
                _print("Displaying...")
                self.data.summarize_current_stage(
                    local_step, global_step, updater.n_experiences, self.n_global_experiences)
                _print("\nMy PID: {}\n".format(os.getpid()))
                _print("Physical memory use: {}mb".format(memory_usage(physical=True)))
                _print("Virtual memory use: {}mb".format(memory_usage(physical=False)))

                _print("Avg time per update: {}s".format(time_per_update))
                _print("Avg time per eval: {}s".format(time_per_eval))
                _print("Avg time for hooks: {}s".format(time_per_hook))

                if cfg.use_gpu:
                    _print(nvidia_smi())

            if checkpoint:
                self.data.dump_data(local_step)

            if save_weights:
                _print("Storing checkpoint weights on step (l={}, g={}), "
                       "constituting (l={}, g={}) experiences, "
                       "with stopping criteria ({}) of {}.".format(
                           local_step, global_step,
                           updater.n_experiences, self.n_global_experiences,
                           self.stopping_criteria_name, stopping_criteria_value))

                if overwrite_weights:
                    weight_path = self.data.path_for(
                        'weights/checkpoint_stage_{}'.format(stage_idx))
                else:
                    weight_path = self.data.path_for(
                        'weights/checkpoint_stage_{}_step_{}'.format(stage_idx, local_step))

                weight_start = time.time()
                weight_path = updater.save(weight_path)
                _print(f"Saved weights to {weight_path}, took {time.time()-weight_start} seconds")

            if backup:
                _print("Backing up experiment directory.")
                _print("src: {}".format(self.exp_dir.path))
                _print("dest: {}".format(cfg.backup_dir))

                command = "rsync -avzu --timeout=300 {src} {dest}".format(
                    src=self.exp_dir.path, dest=cfg.backup_dir,
                )
                execute_command(command, output="loud", robust=True)

            self.local_step += 1
            self.global_step += 1

            # If `do_train` is False, we do no training and evaluate
            # exactly once, so only one iteration is required.
            if not cfg.do_train:
                reason = "`do_train` set to False"
                break

        return threshold_reached, reason


class FrozenTrainingLoopData(ExperimentDirectory):
    """ Interface for the on-disk data generated by a training loop.

    Parameters
    ----------
    path: str or ExperimentDirectory
        Path to the directory for the experiment whose data we want to access.
        If an ExperimentDirectory is given, its `path` attribute is used.

    """
    def __init__(self, path):
        self.path = path.path if isinstance(path, ExperimentDirectory) else path
        # Lazily-populated caches; see the `config` and `history` properties.
        self._config = None
        self._history = None

    def get_summary_path(self, mode):
        """ Return the summaries directory for `mode` (requested from `path_for` with is_dir=True). """
        return self.path_for('summaries/' + mode, is_dir=True)

    def get_data_path(self, mode, stage_idx, local_step):
        """ Return the CSV path holding step data for (`mode`, `stage_idx`, `local_step`).

        The filename format 'localstep=<local_step>.csv' is parsed back by
        `step_data`, so the two must stay in sync.

        """
        local_path = f'data/{mode}/stage{stage_idx}/localstep={local_step}.csv'
        return self.path_for(local_path)

    def step_data(self, mode, stage_slice=None):
        """ Load all step data recorded for `mode` into a single DataFrame.

        Parameters
        ----------
        mode: str
            Subdirectory of `data/` to read from.
        stage_slice: None, int, slice, or tuple
            Selects stages by position among the stage indices sorted in
            increasing numerical order. None loads every stage; an int picks a
            single stage; a slice or a (start, end[, step]) tuple picks a range.

        Returns
        -------
        pandas.DataFrame or None
            All rows, ordered by (stage_idx, local_step), or None if no data files exist.

        """
        stage_dirs = os.listdir(self.path_for('data/{}'.format(mode)))

        # Directory names are 'stage<idx>'. Sort the parsed integers rather than
        # the names: lexical order would put 'stage10' before 'stage2' and make
        # positional `stage_slice` selection pick the wrong stages.
        indices = sorted(int(s[5:]) for s in stage_dirs)

        if stage_slice is None:
            pass
        elif isinstance(stage_slice, int):
            indices = [indices[stage_slice]]
        elif isinstance(stage_slice, slice):
            indices = indices[stage_slice]
        else:
            start, end, *step = stage_slice
            step = step[0] if step else 1
            indices = indices[start:end:step]

        data = {}

        for stage_idx in indices:
            local_path = 'data/{}/stage{}'.format(mode, stage_idx)
            path = self.path_for(local_path)
            files = os.listdir(path) if os.path.isdir(path) else []
            for f in files:
                # Filename created by `get_data_path`: 'localstep=<step>.csv'
                # (the final dump of a stage uses step 'inf').
                local_step = float(f.split('=')[1].split('.')[0])
                data[(stage_idx, local_step)] = pd.read_csv(os.path.join(path, f))

        data_frames = [data[key] for key in sorted(data)]
        if data_frames:
            return pd.concat(data_frames, axis=0, ignore_index=True)
        else:
            return None

    @property
    def config(self):
        """ The experiment's config, loaded lazily and cached.

        Tries `config.pkl` (via dill) first, then `config.json`. Returns None if
        neither could be loaded, in which case loading is retried on next access.

        """
        if self._config is None:
            try:
                with open(self.path_for('config.pkl'), 'rb') as f:
                    self._config = dill.load(f)
            except Exception:
                # Best-effort: the pickle may be missing or unloadable; fall back to JSON.
                try:
                    with open(self.path_for('config.json'), 'r') as f:
                        self._config = json.load(f)
                except Exception:
                    pass

        return self._config

    def get_config_value(self, key):
        """ Return the value stored under `key` in the experiment's config.

        If the config itself could not be loaded (see `config`), fall back to
        scraping `config.txt`: take the least-indented line containing
        `'<key>':` and evaluate the text that follows it.

        Raises
        ------
        KeyError
            If `key` is missing from the loaded config, or no line of
            `config.txt` mentions it.
        ValueError
            If several equally least-indented lines of `config.txt` mention `key`.

        """
        if self.config is not None:
            return self.config[key]

        # A temporary hack to deal with version inconsistencies. Done in pure
        # Python rather than via `grep` in a shell, so `key` is matched literally
        # and cannot inject shell syntax.
        pattern = "'{}':".format(key)
        with open(self.path_for("config.txt"), 'r') as f:
            lines = [line for line in f.read().split('\n') if pattern in line and line.strip()]

        if not lines:
            raise KeyError(key)

        def n_leading_spaces(line):
            return len(line) - len(line.lstrip())

        # Get the line with the least amount of indentation (the outermost occurrence).
        smallest_indent = min(n_leading_spaces(line) for line in lines)
        with_smallest = [line for line in lines if n_leading_spaces(line) == smallest_indent]
        if len(with_smallest) != 1:
            raise ValueError(
                "Ambiguous config key {!r} in config.txt: {}".format(key, with_smallest))
        line = with_smallest[0]

        # Take everything after `'<key>':` (splitting on the first ':' alone breaks
        # when the key or value contains ':'), then drop the final character,
        # presumably the trailing ',' (or closing brace) of the dumped dict entry.
        right = line.split(pattern, 1)[1].strip()[:-1]

        # NOTE(review): `eval` executes arbitrary code from config.txt. This is only
        # safe if the file is trusted (presumably written by this package's own
        # experiment setup) -- do not use on untrusted experiment directories.
        return eval(right)

    @property
    def n_stages(self):
        """ Number of stages recorded in the history. """
        return len(self.history)

    @property
    def history(self):
        """ Per-stage records, loaded lazily from history.json, falling back to history.pkl. """
        if self._history is None:
            try:
                with open(self.path_for('history.json'), 'r') as f:
                    self._history = json.load(f)
            except Exception:
                with open(self.path_for('history.pkl'), 'rb') as f:
                    self._history = dill.load(f)
        return self._history

    @property
    def modes(self):
        """ Names of the entries in the `summaries` directory. """
        return os.listdir(self.path_for('summaries'))


class TrainingLoopData(FrozenTrainingLoopData):
    """ Data structure used by a TrainingLoop to manage data throughout the experiment.

    Buffers per-step records in memory (`self.data`, keyed by mode), periodically
    writes them to CSV files via `dump_data`, and keeps an in-memory list of
    per-stage records (`history`). `setup` must be called before use, and
    `store_scalar_summaries` must be provided by a subclass.

    """

    def setup(self):
        """ Prepare the experiment directory and in-memory state for a training run.

        Records the frozen config and environment, calls `make_directory` for
        `weights`, `plots`, `data` and `summaries`, and resets the history, the
        buffered step data and the summary writers.

        """
        # Record training session environment for later diagnostic purposes
        frozen_config = cfg.freeze()
        git_mode = cfg.get('git_record_mode', 'all')
        self.record_environment(config=frozen_config, dill_recurse=True, git_mode=git_mode)
        self.curriculum = []

        self.make_directory('weights')
        self.make_directory('plots')
        self.make_directory('data')
        self.make_directory('summaries')

        # One dict per stage; see `start_stage` and `record_values_for_stage`.
        self._history = []

        # mode -> list of per-step records not yet written to disk by `dump_data`.
        self.data = defaultdict(list)

        self.stage_idx = -1

        # Summary writers for the current stage; closed in `end_stage`.
        self.writers = {}

    @property
    def history(self):
        """ In-memory list of per-stage records (overrides the lazy on-disk loading). """
        return self._history

    def start_stage(self, stage_idx, stage_config):
        """ Begin a new stage: append its record to the history and reset the writers. """
        self.history.append(dict(stage_idx=stage_idx, stage_config=stage_config))
        self.stage_idx = stage_idx
        self.writers = {}

    def end_stage(self, local_step=None):
        """ Flush buffered step data for the current stage and close its writers. """
        self.dump_data(local_step)

        for w in self.writers.values():
            w.close()

    def dump_data(self, local_step):
        """ Write each mode's buffered step records to CSV, then clear the buffers.

        Parameters
        ----------
        local_step: int or None
            Used in the output filename (see `get_data_path`). None denotes the
            final dump for a stage and is written as step 'inf'.

        """
        if local_step is None:
            local_step = float("inf")  # Final dump for a stage

        for mode, data in self.data.items():
            if data:
                path = self.get_data_path(mode, self.stage_idx, local_step)

                with open(path, 'w') as f:
                    pd.DataFrame.from_records(data).to_csv(f, index=False)

                # Reassigning an existing key does not resize the dict, so this is
                # safe while iterating over `items()`.
                self.data[mode] = []

    def record_values_for_stage(self, d=None, **kwargs):
        """ Record values for the current stage, from `d` and/or keyword arguments. """
        d = d or {}
        self.current_stage_record.update(d)
        self.current_stage_record.update(kwargs)

    def store_scalar_summaries(self, mode, path, record, n_global_experiences):
        """ Write `record` as scalar summaries for `mode` under `path`. Must be overridden. """
        # NotImplementedError subclasses Exception, so existing handlers still apply.
        raise NotImplementedError("NotImplemented")

    def store_step_data_and_summaries(
            self, stage_idx, local_step, global_step, n_local_experiences, n_global_experiences, **data):
        """ Buffer step records and write scalar summaries for every non-empty mode.

        Each keyword argument maps a mode name to a record dict. Unless
        `cfg.store_step_data` is False, a copy of the record annotated with the
        step/experience counters is appended to `self.data[mode]` and is also
        what gets passed to `store_scalar_summaries`.

        """
        for mode, record in data.items():
            if not record:
                continue

            if getattr(cfg, 'store_step_data', True):
                record = record.copy()
                record.update(
                    stage_idx=stage_idx,
                    local_step=local_step,
                    global_step=global_step,
                    n_local_experiences=n_local_experiences,
                    n_global_experiences=n_global_experiences)

                self.data[mode].append(record)

            path = self.get_summary_path(mode)
            self.store_scalar_summaries(mode, path, record, n_global_experiences)

    @property
    def current_stage_record(self):
        """ The record dict of the most recently started stage. """
        return self.history[-1]

    def _finalize(self):
        """ Write all stored data to disk. """
        self.dump_data(None)

        with open(self.path_for('history.json'), 'w') as f:
            json.dump(self.history, f, default=str, indent=4, sort_keys=True)

    def freeze(self):
        """ Finalize on-disk data and return a read-only view of this experiment. """
        self._finalize()
        return FrozenTrainingLoopData(self.path)

    def summarize_current_stage(self, local_step, global_step, n_local_experiences, n_global_experiences):
        """ Print the current stage's record plus the latest buffered record for each mode.

        Dict values are pretty-printed individually; `best_*` values (other than
        paths) and the latest scalar per mode are gathered into one table with
        a 'best' column and one column per mode.

        """
        stage_idx = self.current_stage_record['stage_idx']

        print("\n{} Summary: Stage={}, Step(l={}, g={}), Experiences(l={}, g={}) {}\n".format(
            "*" * 20, stage_idx, local_step, global_step,
            n_local_experiences, n_global_experiences, "*" * 20))

        data = defaultdict(dict)

        for k, v in sorted(self.current_stage_record.items()):
            if isinstance(v, dict):
                v = "\n" + pformat(v, indent=2)
                print("* {}: {}".format(k, v))
            elif k.endswith("_path") or not k.startswith("best_"):
                print("* {}: {}".format(k, v))
            else:
                # Strip the 'best_' prefix so the value lines up with the per-mode columns.
                data[k[5:]]['best'] = v

        for mode, mode_data in sorted(self.data.items()):
            if mode_data:
                record = mode_data[-1] or {}
                for k, v in sorted(record.items()):
                    if isinstance(v, dict):
                        v = "\n" + pformat(v, indent=2)
                        print("* {}_{}: {}".format(mode, k, v))
                    else:
                        data[k][mode] = v

        headers = ["key", "best"] + sorted(self.data)
        table = [
            [key] + [row.get(k, None) for k in headers[1:]]
            for key, row in sorted(data.items())]

        print(tabulate(table, headers=headers, tablefmt="psql"))

    def summarize(self):
        """ Print a summary of the values recorded for every stage.

        Dict values and strings longer than 20 characters are printed one per
        line; all other values go into a table with one row per (flattened) key
        and one column per stage.

        """
        print("\n" + "-" * 30 + " Stage-by-Stage Summary " + "-" * 30 + "\n")

        table = defaultdict(dict)

        for record in self.history:
            stage_idx = record['stage_idx']
            print("\n" + "-" * 20 + " Stage {} ".format(stage_idx) + "-" * 20)

            record = Config(record).flatten()

            for k, v in sorted(record.items()):
                if isinstance(v, dict):
                    v = "\n" + pformat(v, indent=2)
                    print("* {}: {}".format(k, v))
                elif isinstance(v, str) and len(v) > 20:
                    print("* {}: {}".format(k, v))
                else:
                    table[k][stage_idx] = v

        # Columns are the stage indices actually present in the history (in order,
        # without duplicates), since `table` is keyed by them. `range(len(history))`
        # would leave every cell empty whenever the indices are not exactly 0..n-1.
        stage_indices = list(dict.fromkeys(record['stage_idx'] for record in self.history))
        headers = ["key"] + stage_indices
        table = [
            [key] + [row.get(k, None) for k in headers[1:]]
            for key, row in sorted(table.items())]

        print()
        print(tabulate(table, headers=headers, tablefmt="psql"))
        print()


class Hook:
    """ Base class for callbacks invoked by the training loop.

    Parameters
    ----------
    n: int
        When truthy, the hook is invoked every `n` steps throughout training.
    initial: bool
        When True, the hook is invoked on the first step of a stage.
    final: bool
        When True, the hook is invoked at the end of a stage, after the best
        hypothesis has been loaded.

    """
    def __init__(self, n=None, initial=False, final=False):
        self.n, self.initial, self.final = n, initial, final

    @property
    def call_per_timestep(self):
        """ Whether the hook runs periodically, i.e. whether `n` is truthy. """
        return True if self.n else False

    def _attrs(self):
        """ Names of the attributes included in `str(self)`. """
        return ["n", "initial", "final"]

    def __str__(self):
        settings = []
        for attr_name in self._attrs():
            settings.append(f"{attr_name}={getattr(self, attr_name)}")
        return f"{self.__class__.__name__}({', '.join(settings)})"

    def __repr__(self):
        return str(self)

    def start_stage(self, training_loop, updater, stage_idx):
        """ Called at the beginning of every stage. """

    def end_stage(self, training_loop, updater, stage_idx):
        """ Called at the end of every stage, after best hypothesis has been reloaded. """

    def step(self, training_loop, updater, step_idx):
        """ Per-step callback; the base implementation does nothing. """

    def final_step(self, training_loop, updater):
        """ Called during final testing for a stage. """

    def _print(self, s):
        """ Print `s` prefixed with the hook's class name. """
        print(f"{self.__class__.__name__}: {s}")
